import random
from random import randrange
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainerCallback, TrainerState, TrainerControl
from datasets import load_dataset, DatasetDict, Dataset
import pickle
from functools import partial
from tqdm import tqdm
from trl import SFTConfig, SFTTrainer
from nltk.translate.bleu_score import sentence_bleu
from rouge_score import rouge_scorer
import numpy as np
from peft import LoraConfig
import re
import argparse
from transformers import DataCollatorForSeq2Seq
import warnings
from openai import OpenAI
import time

# Ignore all warnings
warnings.filterwarnings("ignore")

# Marker that separates the query part from the description part of a raw response.
_DESCRIPTION_MARKER = 'step_by_step_description'
# Dict-literal punctuation removed when a field is recovered from the raw response text.
_RESPONSE_PUNCTUATION = str.maketrans('', '', "':{}")


def format_instruction(sample):
    """Build the ``[query, apis_desc, step_by_step_desc]`` triple for one sample.

    The API descriptions are looked up through the module-level ``stat`` and
    ``identifier2python`` tables: the action names registered under
    ``sample['key']`` are converted to identifiers (``.`` -> ``_``) and names
    without a known description are skipped.

    ``query`` and ``description`` are taken from the sample when present.
    Otherwise they are recovered from ``sample['response']`` by splitting on
    ``step_by_step_description`` and stripping quote, colon and brace
    characters (and, for the query, the literal word ``query``).

    Args:
        sample: mapping with ``key`` and ``response`` entries and optional
            ``query`` / ``description`` entries.

    Returns:
        ``[query, apis_desc, step_by_step_desc]``, or ``None`` when a missing
        field cannot be recovered from the response (empty result, or the
        ``step_by_step_description`` marker is absent).
    """
    text = sample['response']

    action_names = stat[sample['key']]['action_names']
    known_descriptions = (
        identifier2python.get(action_name.replace('.', '_'))
        for action_name in action_names
    )
    apis_desc = "\n".join(desc for desc in known_descriptions if desc is not None)

    query = sample.get('query')
    if query is None:
        head = text.split(_DESCRIPTION_MARKER)[0]
        query = head.translate(_RESPONSE_PUNCTUATION).replace('query', '').strip()
        if not query:
            return None

    step_by_step_desc = sample.get('description')
    if step_by_step_desc is None:
        parts = text.split(_DESCRIPTION_MARKER)
        # Without the marker there is no description segment; treat the sample
        # as unparseable instead of raising IndexError.
        if len(parts) < 2:
            return None
        step_by_step_desc = parts[1].translate(_RESPONSE_PUNCTUATION).strip()
        if not step_by_step_desc:
            return None

    return [query, apis_desc, step_by_step_desc]



def prompt_fn(input, model_name):
    """Send ``input`` as a single user message and return the reply text.

    Args:
        input: prompt text placed in the one user message of the request.
        model_name: model identifier forwarded to the chat-completions call.

    Returns:
        The ``content`` of the first choice's message in the response.
    """
    request = dict(
        model=model_name,
        messages=[{"role": "user", "content": input}],
        temperature=0.2,
        top_p=1.0,
        n=1,
        stream=False,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        logit_bias={},
    )

    # NOTE(review): api_key and base_url are empty strings here — presumably
    # placeholders that must be filled in before the script can run; confirm.
    client = OpenAI(api_key='', base_url="")

    completion = client.chat.completions.create(**request)
    payload = completion.model_dump()
    first_choice = payload['choices'][0]
    return first_choice['message']['content']


def process_sample(examples, args, apis=None, max_retries=20, retry_delay=30):
    """Ask the LLM to write a query that uses a given set of APIs.

    Args:
        examples: in-context examples text inserted into the prompt.
        args: object exposing the model name as ``args.llm``.
        apis: API descriptions the generated query must use. When omitted,
            the module-level ``apis`` variable is used, which is where the
            prompt took this text from before the parameter existed.
        max_retries: maximum number of request attempts.
        retry_delay: seconds to wait between failed attempts.

    Returns:
        ``{'apis': <api text>, 'response': <model reply>}`` on success, or
        ``None`` when every attempt failed.

    Raises:
        AttributeError: if ``args`` has no ``llm`` attribute.
        ValueError: if ``apis`` is not given and no module-level ``apis`` exists.
    """
    if apis is None:
        apis = globals().get('apis')
        if apis is None:
            raise ValueError("process_sample needs `apis` (argument or module-level variable)")

    # Resolve the model name before the retry loop so a wrong `args` object
    # fails immediately instead of being retried as if it were an API error.
    model_name = args.llm

    final_prompt = f"""
Assuming you are an expert at using different apis to solve a complex problem. You can generate a query that can be solved with the following apis.
Here are some apis and corresponding query:
{examples}
Please generate a query using the following apis:
{apis}
The following criterias should be followed:
1. Should use all the above apis;
2. The queries should be broad enough to be simple.
query:
"""

    for attempt in range(1, max_retries + 1):
        try:
            # Only the remote call is guarded: programming errors elsewhere
            # must surface rather than trigger sleeps.
            response = prompt_fn(final_prompt, model_name)
        except Exception as err:
            print(f"Request failed on attempt {attempt}/{max_retries}: {err!r}")
            if attempt < max_retries:
                print(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
        else:
            return {'apis': apis, 'response': response}

    # Every attempt failed.
    return None





# The model name is required by process_sample (read as ``args.llm``).
parser = argparse.ArgumentParser(
    description="Generate synthetic queries for randomly sampled APIs with an LLM."
)
parser.add_argument('--llm', required=True, help="model name passed to the chat-completions API")
args = parser.parse_args()

# Load the lookup tables used by format_instruction.
# NOTE(security): pickle.load can execute arbitrary code from the file; only
# load pickles that were produced locally / come from a trusted source.
with open('../../data/statistics.pkl', 'rb') as fp:
    stat = pickle.load(fp)
with open('../../data/identifier2python.pkl', 'rb') as fp:
    identifier2python = pickle.load(fp)


data = load_dataset('json', data_files='query_and_description.json')['train']
data = [sample for sample in data if sample['key'] in stat]

# format_instruction returns None for samples it cannot parse; drop those so
# the in-context examples below can always be indexed.
detailed_data = [
    formatted
    for formatted in map(format_instruction, data)
    if formatted is not None
]

ICL_number = 3
apis_number = 20
ALL_API_DESC = list(identifier2python.values())

generated_queries = []
for i in range(10):
    # random.sample raises ValueError when asked for more items than exist,
    # so cap both sample sizes at the population size.
    ICL_examples = random.sample(detailed_data, min(ICL_number, len(detailed_data)))

    # Separate the examples with a blank line; without it the next example's
    # index is glued onto the previous example's query text.
    ICL_context = "\n\n".join(
        f"{idx}.\n\napis:\n{example[1]}\nquery:{example[0]}"
        for idx, example in enumerate(ICL_examples)
    )

    # Kept as a module-level name: process_sample reads ``apis`` from module
    # scope when building its prompt.
    apis = '\n\n'.join(random.sample(ALL_API_DESC, min(apis_number, len(ALL_API_DESC))))

    query = process_sample(ICL_context, args)
    if query is not None:
        generated_queries.append(query)



print('dada')



